package org.zjvis.datascience.common.algo;

import com.alibaba.fastjson.JSONArray;
import com.alibaba.fastjson.JSONObject;
import org.apache.commons.lang3.StringUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.zjvis.datascience.common.constant.SqlTemplate;
import org.zjvis.datascience.common.enums.AlgEnum;
import org.zjvis.datascience.common.enums.SubTypeEnum;
import org.zjvis.datascience.common.sql.SqlHelper;
import org.zjvis.datascience.common.util.ToolUtil;
import org.zjvis.datascience.common.vo.TaskVO;

import java.util.Arrays;
import java.util.List;
import java.util.Objects;

/**
 * @description Template class for the FP-Growth pattern-detection algorithm node.
 * @date 2021-12-24
 */
public class FPGrowthAlg extends BaseAlg {

    private static final Logger logger = LoggerFactory.getLogger("FPGrowthAlg");

    /** Template file handed to {@code getTemplateParamList} and {@code supplementForCheckbox}. */
    private static final String TPL_FILENAME = "template/algo/fp_growth.json";

    /**
     * MADlib-engine call on the full source table. Placeholders, in order:
     * schema, source table, output table, feature columns, min frequency, min support.
     */
    private static final String SQL_TPL_MADLIB =
            "SELECT * FROM \"%s\".\"fp_growth\"('%s', '%s', '%s', %s, %s)";

    /**
     * Sampled variant: the first argument is a {@code CREATE VIEW} statement restricting the source
     * table by the id column. Placeholders, in order: schema, view name, source table, id column,
     * sample bound, output table, feature columns, min frequency, min support.
     */
    private static final String SQL_TPL_MADLIB_SAMPLE =
            "SELECT * FROM \"%s\".\"fp_growth\"('CREATE VIEW %s AS SELECT * from %s where \"%s\" <= %s', '%s', '%s', %s, %s)";

    /**
     * Spark-engine command line. Placeholders, in order: source table, feature columns,
     * association table, output table, timestamp, id column, min support, min confidence,
     * min frequency.
     */
    private static final String SQL_TPL_SPARK =
            "fp-growth -s %s -f %s -t %s -t2 %s -uk %d -idcol %s -mins %s -minc %s -minf %d";

    /**
     * Populates {@code data} with this node's parameter template.
     *
     * @param data node configuration; receives the parameter list under the key {@code "setParams"}
     *             before being passed to {@code baseInitTemplate}
     */
    public void initTemplate(JSONObject data) {
        JSONArray jsonArray = getTemplateParamList(TPL_FILENAME);
        data.put("setParams", jsonArray);
        baseInitTemplate(data);
    }

    /**
     * Registers the node as an FP-Growth pattern-detection algorithm and limits it to a single
     * parent ({@code maxParentNumber = 1}).
     */
    public FPGrowthAlg() {
        super(AlgEnum.FP_GROWTH.name(), SubTypeEnum.PATTERN_DETECTION.getVal(),
                SubTypeEnum.PATTERN_DETECTION.getDesc());
        this.maxParentNumber = 1;
    }

    /**
     * Renders the statement that runs FP-Growth.
     *
     * <p>When {@code sampleTable} is non-empty the sampled MADlib template is used regardless of
     * the configured engine; otherwise the template is chosen from {@code getEngine()}.
     *
     * @param sourceTable      input table name
     * @param outTable         output table name
     * @param featureCols      feature column expression; an empty value yields an empty result
     * @param minFreq          minimum pattern frequency
     * @param minSupport       minimum support
     * @param associationTable association table name (Spark template only)
     * @param timeStamp        value for the Spark template's {@code -uk} option
     * @param confidence       minimum confidence (Spark template only)
     * @param sampleTable      name of the sample view to create, or empty for a full run
     * @return the rendered statement, or an empty string when {@code featureCols} is empty or the
     *         engine is neither MADlib nor Spark
     */
    private String getFPGrowthSql(String sourceTable, String outTable, String featureCols, int minFreq, float minSupport,
                                  String associationTable, long timeStamp, float confidence, String sampleTable) {
        if (StringUtils.isEmpty(featureCols)) {
            return StringUtils.EMPTY;
        }
        if (StringUtils.isNotEmpty(sampleTable)) {
            // Sampling: run on a view limited by the id column.
            return String.format(SQL_TPL_MADLIB_SAMPLE, SqlTemplate.SCHEMA,
                    sampleTable, sourceTable, ID_COL, SAMPLE_NUMBER, outTable, featureCols, minFreq,
                    minSupport);
        }
        // Full data set: the template depends on the configured engine.
        if (getEngine().isMadlib()) {
            return String.format(SQL_TPL_MADLIB, SqlTemplate.SCHEMA, sourceTable, outTable,
                    featureCols, minFreq, minSupport);
        }
        if (getEngine().isSpark()) {
            return String.format(SQL_TPL_SPARK, sourceTable, featureCols,
                    associationTable, outTable, timeStamp, ID_COL, minSupport, confidence, minFreq);
        }
        return StringUtils.EMPTY;
    }

    /**
     * Builds the FP-Growth statement from the node configuration.
     *
     * <p>Reads {@code source_table}, {@code out_table_rename}, {@code min_frequency},
     * {@code min_support}, {@code feature_cols}, {@code min_confidence} and {@code isSample} from
     * {@code json}. When {@code isSample} is absent, null, {@code "SUCCESS"} or {@code "FAIL"}, a
     * sample view name is derived from the output table name ({@code "solid_"} replaced by
     * {@code "view_"}) and {@code isSample} is set to {@code "CREATE"} in {@code json}.
     *
     * @param json       node configuration; may be mutated as described above
     * @param sqlHelpers not used by this implementation
     * @param timeStamp  timestamp used to align table names and passed to the Spark template
     * @param engineName engine name stored on this instance before the statement is built
     * @return the rendered statement (possibly empty, see {@code getFPGrowthSql})
     * @throws NullPointerException if {@code min_frequency}, {@code min_support} or
     *                              {@code min_confidence} is missing from {@code json}
     */
    public String initSql(JSONObject json, List<SqlHelper> sqlHelpers, long timeStamp, String engineName) {
        this.engineName = engineName;
        String sourceTable = ToolUtil.alignTableName(json.getString("source_table"), timeStamp);
        String rawOutTable = json.getString("out_table_rename");
        String associationTable = ToolUtil.alignTableName(rawOutTable + "rule_", timeStamp);
        String outTable = ToolUtil.alignTableName(rawOutTable, timeStamp);

        // Name the missing key instead of failing with a message-less NPE during auto-unboxing.
        int minFreq = Objects.requireNonNull(json.getInteger("min_frequency"),
                "min_frequency is required");
        float minSupport = Objects.requireNonNull(json.getFloat("min_support"),
                "min_support is required");
        JSONArray features = json.getJSONArray("feature_cols");
        float minConfidence = Objects.requireNonNull(json.getFloat("min_confidence"),
                "min_confidence is required");
        String featureCols = this.getFeatureColsStr(features);

        String sampleTable = "";
        // Constant-first equals: a key that is present but null is treated like an absent key.
        String sampleState = json.getString("isSample");
        if (sampleState == null || "SUCCESS".equals(sampleState) || "FAIL".equals(sampleState)) {
            sampleTable = outTable.replace("solid_", "view_");
            json.put("isSample", "CREATE");
        }

        String sql = this.getFPGrowthSql(sourceTable, outTable, featureCols, minFreq, minSupport,
                associationTable, timeStamp, minConfidence, sampleTable);
        logger.debug("sql={}", sql);
        return sql;
    }

    /**
     * Describes the node's output table in the task data.
     *
     * <p>Always stores the generated table name under {@code out_table_rename}. If the task has no
     * {@code input} entries, a warning is logged and nothing else is written. Otherwise the first
     * input's {@code tableName} becomes {@code source_table}, and a single output descriptor with
     * the columns {@code id, pattern, frequency, support} is stored under {@code output}.
     *
     * @param vo task whose data object is read and updated
     */
    public void defineOutput(TaskVO vo) {
        JSONObject jsonObject = vo.getData();
        String outTablePrefix = jsonObject.getString("out_table");
        String tableName = String
                .format(SqlTemplate.OUT_TABLE_NAME, outTablePrefix, vo.getPipelineId(), vo.getId());
        jsonObject.put("out_table_rename", tableName);

        JSONArray input = jsonObject.getJSONArray("input");
        if (input == null || input.isEmpty()) {
            logger.warn("input is empty");
            return;
        }
        this.checkBoxSelectFilter(jsonObject, "", FEATURE_COLS);
        this.supplementForCheckbox(jsonObject, TPL_FILENAME, 1, vo);
        jsonObject.put("source_table", input.getJSONObject(0).getString("tableName"));

        // Fixed result schema; names and types are index-aligned.
        String[] cols = new String[]{"id", "pattern", "frequency", "support"};
        String[] types = new String[]{"integer", "array", "integer", "double precision"};

        JSONArray outputCols = new JSONArray();
        outputCols.addAll(Arrays.asList(cols));
        JSONArray outputColTypes = new JSONArray();
        outputColTypes.addAll(Arrays.asList(types));

        JSONObject outItem = new JSONObject();
        outItem.put("tableName", tableName);
        outItem.put("tableCols", outputCols);
        // Fall back to the algorithm's enum name when the task has no display name.
        outItem.put("nodeName", vo.getName() == null ? AlgEnum.FP_GROWTH.toString() : vo.getName());
        outItem.put("columnTypes", outputColTypes);
        this.setSubTypeForOutput(outItem);

        JSONArray output = new JSONArray();
        output.add(outItem);
        jsonObject.put("output", output);
        vo.setData(jsonObject);
    }
}
